Classification with B-Spline Basis Function Coefficients

Note

Last updated 12 AM, 8/17/2020

import pickle as pkl
import numpy as np
import pandas as pd

import torch
from torch import nn, optim
from torch.utils.data import Dataset, TensorDataset, DataLoader
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from sklearn.model_selection import KFold

import plotly.graph_objects as go
from plotly.subplots import make_subplots
import plotly.io as pio
pio.renderers.default = 'notebook'

from IPython.display import display
from ipythonblocks import BlockGrid
from webcolors import name_to_rgb
from scipy import interpolate
from sympy import lambdify, bspline_basis_set, symbols
import warnings
warnings.filterwarnings('ignore')
# Build a 1x15 legend strip: one colored block per clip, using the same
# palette the later plots index into with colors[clip].
grid = BlockGrid(15, 1, fill=(0, 0, 0))
grid.block_size = 50
grid.lines_on = False

colors = ['slategray','sienna','darkred','crimson','darkorange','darkgoldenrod','darkkhaki','mediumseagreen','darkgreen','darkcyan','cornflowerblue','mediumblue','blueviolet','purple','hotpink']
for block, color_name in zip(grid, colors):
    rgb = name_to_rgb(color_name)
    block.set_colors(rgb[0], rgb[1], rgb[2])

# Load the preprocessed trajectory tables produced in the previous section.
with open('trajectories_updated.pkl', 'rb') as f:
    data = pkl.load(f)
traj_df = data['traj_df']
mean_df = data['mean_df']
clipdata_df = data['clipdata_df']

B-Splines

From the previous section, we have shown that trajectories can be effectively smoothed with B-splines.

A key concept behind B-splines is that any B-spline can be created from a linear combination of basis functions. Additionally, given the same knot vector and spline order/degree, the B-spline basis functions will always be the same. These two properties mean that the coefficients of the basis functions are sufficient to describe a B-spline.

Let’s take a look at some trajectories and their bases. We’ll use a cubic B-spline and the same knot vector with 20 bases.

Bases

# find basis functions
# Fit a cubic B-spline to each (clip, participant) trajectory with a shared
# knot vector, then recover the symbolic basis functions it implies.
clips = ['oceans','overcome']
pids = [1,2]

# Object arrays indexed [clip][participant]: `bases` holds the sympy basis
# sets, `info` a human-readable label used when printing them.
bases = np.empty((len(clips),len(pids)), dtype=object)
info = np.empty((len(clips),len(pids)), dtype=object)

for clip, clip_name in enumerate(clips):
    for pid_i, pid in enumerate(pids):
        temp_df = traj_df[(traj_df.clip_name==clip_name) & (traj_df.pid==pid)]
        data = temp_df[['x','y','z']].to_numpy()
        # task=-1 requests a least-squares spline on the supplied knots `t`;
        # `u` parameterizes the trajectory evenly over [0, 1].
        # NOTE(review): 24 knots with k=3 is presumably what yields the 20
        # bases discussed in the text — confirm against scipy's knot-count
        # convention for splprep.
        tck, u = interpolate.splprep(data.T, k=3, u=np.linspace(0,1,temp_df['clip_len'].iloc[0]), t=np.linspace(0,1,24), task=-1)
        # NOTE(review): this writes to a filtered copy of traj_df and the
        # result is never read again — looks like dead code; confirm.
        temp_df['clip_len'] = temp_df['clip_len']-1

        # Build the symbolic basis set from the fitted knot vector (tck[0])
        # and spline degree (tck[2]).
        u_sym = symbols('u')
        basis = bspline_basis_set(tck[2], tck[0].tolist(), u_sym)

        bases[clip][pid_i] = basis
        info[clip][pid_i] = clip_name+' '+str(pid)

        
        
# print basis functions:
# Show the first three and last three bases; -1 is a sentinel meaning
# "print an ellipsis between the two groups".
n_bases = len(bases[0][0])
bases_to_display = [0, 1, 2, -1, n_bases - 3, n_bases - 2, n_bases - 1]
for b in bases_to_display:
    if b == -1:
        print('...')
    else:
        print(f'Basis {b}:')
        for c in range(len(bases)):
            for p in range(len(bases[0])):
                # Each Piecewise arg is an (expression, condition) pair; the
                # final catch-all zero piece is skipped. Print the condition,
                # then the polynomial indented underneath it.
                parts = []
                for expr, cond in bases[c][p][b].args[:-1]:
                    parts.append(str(cond))
                    parts.append('\n' + ' ' * 17 + str(expr))
                    parts.append('\n' + ' ' * 13)
                print(f'{info[c][p]:10s} : {"".join(parts)}')
    print()
Basis 0:
oceans 1   : (u >= 0) & (u <= 0.173913043478261)
                 -190.109375*u**3 + 99.1875*u**2 - 17.25*u + 1.0
             
oceans 2   : (u >= 0) & (u <= 0.173913043478261)
                 -190.109375*u**3 + 99.1875*u**2 - 17.25*u + 1.0
             
overcome 1 : (u >= 0) & (u <= 0.173913043478261)
                 -190.109375*u**3 + 99.1875*u**2 - 17.25*u + 1.0
             
overcome 2 : (u >= 0) & (u <= 0.173913043478261)
                 -190.109375*u**3 + 99.1875*u**2 - 17.25*u + 1.0
             

Basis 1:
oceans 1   : (u >= 0) & (u <= 0.173913043478261)
                 463.866875*u**3 - 178.5375*u**2 + 17.25*u
             (u >= 0.173913043478261) & (u <= 0.217391304347826)
                 -486.68*u**3 + 317.4*u**2 - 69.0*u + 5.0
             
oceans 2   : (u >= 0) & (u <= 0.173913043478261)
                 463.866875*u**3 - 178.5375*u**2 + 17.25*u
             (u >= 0.173913043478261) & (u <= 0.217391304347826)
                 -486.68*u**3 + 317.4*u**2 - 69.0*u + 5.0
             
overcome 1 : (u >= 0) & (u <= 0.173913043478261)
                 463.866875*u**3 - 178.5375*u**2 + 17.25*u
             (u >= 0.173913043478261) & (u <= 0.217391304347826)
                 -486.68*u**3 + 317.4*u**2 - 69.0*u + 5.0
             
overcome 2 : (u >= 0) & (u <= 0.173913043478261)
                 463.866875*u**3 - 178.5375*u**2 + 17.25*u
             (u >= 0.173913043478261) & (u <= 0.217391304347826)
                 -486.68*u**3 + 317.4*u**2 - 69.0*u + 5.0
             

Basis 2:
oceans 1   : (u >= 0) & (u <= 0.173913043478261)
                 -375.149166666667*u**3 + 79.35*u**2
             (u >= 0.173913043478261) & (u <= 0.217391304347826)
                 1906.16333333333*u**3 - 1110.9*u**2 + 207.0*u - 12.0
             (u >= 0.217391304347826) & (u <= 0.260869565217391)
                 -1013.91666666667*u**3 + 793.5*u**2 - 207.0*u + 18.0
             
oceans 2   : (u >= 0) & (u <= 0.173913043478261)
                 -375.149166666667*u**3 + 79.35*u**2
             (u >= 0.173913043478261) & (u <= 0.217391304347826)
                 1906.16333333333*u**3 - 1110.9*u**2 + 207.0*u - 12.0
             (u >= 0.217391304347826) & (u <= 0.260869565217391)
                 -1013.91666666667*u**3 + 793.5*u**2 - 207.0*u + 18.0
             
overcome 1 : (u >= 0) & (u <= 0.173913043478261)
                 -375.149166666667*u**3 + 79.35*u**2
             (u >= 0.173913043478261) & (u <= 0.217391304347826)
                 1906.16333333333*u**3 - 1110.9*u**2 + 207.0*u - 12.0
             (u >= 0.217391304347826) & (u <= 0.260869565217391)
                 -1013.91666666667*u**3 + 793.5*u**2 - 207.0*u + 18.0
             
overcome 2 : (u >= 0) & (u <= 0.173913043478261)
                 -375.149166666667*u**3 + 79.35*u**2
             (u >= 0.173913043478261) & (u <= 0.217391304347826)
                 1906.16333333333*u**3 - 1110.9*u**2 + 207.0*u - 12.0
             (u >= 0.217391304347826) & (u <= 0.260869565217391)
                 -1013.91666666667*u**3 + 793.5*u**2 - 207.0*u + 18.0
             

...

Basis 17:
oceans 1   : (u >= 0.739130434782609) & (u <= 0.782608695652174)
                 1013.91666666667*u**3 - 2248.25*u**2 + 1661.75*u - 409.416666666667
             (u >= 0.782608695652174) & (u <= 0.826086956521739)
                 -1906.16333333333*u**3 + 4607.58999999999*u**2 - 3703.68999999999*u + 990.263333333331
             (u >= 0.826086956521739) & (u <= 1.0)
                 375.149166666666*u**3 - 1046.0975*u**2 + 966.747499999999*u - 295.799166666667
             
oceans 2   : (u >= 0.739130434782609) & (u <= 0.782608695652174)
                 1013.91666666667*u**3 - 2248.25*u**2 + 1661.75*u - 409.416666666667
             (u >= 0.782608695652174) & (u <= 0.826086956521739)
                 -1906.16333333333*u**3 + 4607.58999999999*u**2 - 3703.68999999999*u + 990.263333333331
             (u >= 0.826086956521739) & (u <= 1.0)
                 375.149166666666*u**3 - 1046.0975*u**2 + 966.747499999999*u - 295.799166666667
             
overcome 1 : (u >= 0.739130434782609) & (u <= 0.782608695652174)
                 1013.91666666667*u**3 - 2248.25*u**2 + 1661.75*u - 409.416666666667
             (u >= 0.782608695652174) & (u <= 0.826086956521739)
                 -1906.16333333333*u**3 + 4607.58999999999*u**2 - 3703.68999999999*u + 990.263333333331
             (u >= 0.826086956521739) & (u <= 1.0)
                 375.149166666666*u**3 - 1046.0975*u**2 + 966.747499999999*u - 295.799166666667
             
overcome 2 : (u >= 0.739130434782609) & (u <= 0.782608695652174)
                 1013.91666666667*u**3 - 2248.25*u**2 + 1661.75*u - 409.416666666667
             (u >= 0.782608695652174) & (u <= 0.826086956521739)
                 -1906.16333333333*u**3 + 4607.58999999999*u**2 - 3703.68999999999*u + 990.263333333331
             (u >= 0.826086956521739) & (u <= 1.0)
                 375.149166666666*u**3 - 1046.0975*u**2 + 966.747499999999*u - 295.799166666667
             

Basis 18:
oceans 1   : (u >= 0.782608695652174) & (u <= 0.826086956521739)
                 486.679999999999*u**3 - 1142.64*u**2 + 894.239999999998*u - 233.279999999999
             (u >= 0.826086956521739) & (u <= 1.0)
                 -463.866875*u**3 + 1213.063125*u**2 - 1051.775625*u + 302.579375
             
oceans 2   : (u >= 0.782608695652174) & (u <= 0.826086956521739)
                 486.679999999999*u**3 - 1142.64*u**2 + 894.239999999998*u - 233.279999999999
             (u >= 0.826086956521739) & (u <= 1.0)
                 -463.866875*u**3 + 1213.063125*u**2 - 1051.775625*u + 302.579375
             
overcome 1 : (u >= 0.782608695652174) & (u <= 0.826086956521739)
                 486.679999999999*u**3 - 1142.64*u**2 + 894.239999999998*u - 233.279999999999
             (u >= 0.826086956521739) & (u <= 1.0)
                 -463.866875*u**3 + 1213.063125*u**2 - 1051.775625*u + 302.579375
             
overcome 2 : (u >= 0.782608695652174) & (u <= 0.826086956521739)
                 486.679999999999*u**3 - 1142.64*u**2 + 894.239999999998*u - 233.279999999999
             (u >= 0.826086956521739) & (u <= 1.0)
                 -463.866875*u**3 + 1213.063125*u**2 - 1051.775625*u + 302.579375
             

Basis 19:
oceans 1   : (u >= 0.826086956521739) & (u <= 1.0)
                 190.109375*u**3 - 471.140625*u**2 + 389.203125*u - 107.171875
             
oceans 2   : (u >= 0.826086956521739) & (u <= 1.0)
                 190.109375*u**3 - 471.140625*u**2 + 389.203125*u - 107.171875
             
overcome 1 : (u >= 0.826086956521739) & (u <= 1.0)
                 190.109375*u**3 - 471.140625*u**2 + 389.203125*u - 107.171875
             
overcome 2 : (u >= 0.826086956521739) & (u <= 1.0)
                 190.109375*u**3 - 471.140625*u**2 + 389.203125*u - 107.171875
             

From these examples, we can see that the bases are identical regardless of the clip or participant, which is consistent with the definition of a B-spline. We can also notice that they are sorted by \(u\), which corresponds to time in the trajectory.

Coefficients

Now we can calculate the coefficients for each trajectory. Since our trajectories are in 3-space, there will be a coefficient for each dimension (x,y,z) that corresponds to each basis. We’ll also increase the number of bases to 50.

# Number of basis functions (and hence coefficients per spatial dimension).
num_coeff = 50

# Flat accumulators: one coefficient per (clip, pid, basis) combination,
# appended in exactly that nesting order.
c_x = np.empty(0)
c_y = np.empty(0)
c_z = np.empty(0)
basis = np.empty(0,dtype=int)
for clip, clip_name in enumerate(clipdata_df['clip_name']):
    temp_df = traj_df[(traj_df.clip_name==clip_name)]
    # NOTE(review): assumes pids are numbered 1..max with no gaps — confirm.
    for pid in range(max(temp_df.pid)):
        data = temp_df[temp_df.pid==pid+1][['x','y','z']].to_numpy()
        # Least-squares B-spline fit (task=-1) on a shared knot vector so
        # every trajectory is expressed in the same basis.
        tck, u = interpolate.splprep(data.T, k=3, u=np.linspace(0,1,temp_df['clip_len'].iloc[0]), t=np.linspace(0,1,num_coeff+4), task=-1)
        # tck[1] holds one coefficient array per spatial dimension (x, y, z).
        c_x = np.append(c_x, tck[1][0])
        c_y = np.append(c_y, tck[1][1])
        c_z = np.append(c_z, tck[1][2])
        basis = np.append(basis, np.arange(0,num_coeff,dtype=int))
    # NOTE(review): this reassignment is overwritten at the top of the next
    # iteration and never read — looks like dead code; confirm.
    temp_df = temp_df[temp_df.time==1].drop(columns=['time'])

# Start from one row per (clip, pid), then duplicate each row num_coeff times
# so every row can carry one basis index and its coefficient triple.
coeff_df = traj_df[traj_df.time==1].drop(columns=['time'])
coeff_df = coeff_df.iloc[np.arange(len(coeff_df)).repeat(num_coeff)] # duplicate rows
coeff_df['basis'] = basis
coeff_df['x'] = c_x
coeff_df['y'] = c_y
coeff_df['z'] = c_z
coeff_df = coeff_df[['clip','clip_name','clip_len','pid','basis','x','y','z']] # reorder columns
coeff_df = coeff_df.rename(columns={'x': 'c_x', 'y': 'c_y', 'z': 'c_z'})
coeff_df = coeff_df.reset_index(drop=True)

display(coeff_df)
clip clip_name clip_len pid basis c_x c_y c_z
0 0 testretest 84 1 0 -0.352328 0.181272 -0.192012
1 0 testretest 84 1 1 0.926391 0.761404 0.996256
2 0 testretest 84 1 2 -1.239233 -0.238418 1.699006
3 0 testretest 84 1 3 -1.511278 -2.617598 1.343270
4 0 testretest 84 1 4 -4.503501 -3.981479 3.552777
... ... ... ... ... ... ... ... ...
68395 14 starwars 256 76 45 -5.129921 -15.103396 18.735031
68396 14 starwars 256 76 46 -4.398602 -9.792690 14.930604
68397 14 starwars 256 76 47 -1.137374 -23.311353 21.881870
68398 14 starwars 256 76 48 -10.233234 -12.591830 8.810135
68399 14 starwars 256 76 49 -3.503395 -14.279895 16.766243

68400 rows × 8 columns

Before classification, we can take a look at the coefficients.

# 3x2 grid: rows are the x/y/z coefficient dimensions; the left column shows
# the per-basis mean coefficient alone, the right column repeats it with a
# +/- 1 sd shaded band.
fig = make_subplots(rows=3, cols=2, 
                    shared_xaxes=True,
                    vertical_spacing=0.03, horizontal_spacing=0.05,
                    subplot_titles=('x','x','y','y','z','z'), 
                    specs=[[{'type':'scatter'}, {'type':'scatter'}], [{'type':'scatter'}, {'type':'scatter'}], [{'type':'scatter'}, {'type':'scatter'}]])

for clip, clip_name in enumerate(clipdata_df['clip_name']):

    # smoothed (splines)
    # Per-basis mean and sd of each coefficient across all participants for
    # this clip (transform broadcasts the aggregate back onto every row).
    temp_df = coeff_df[(coeff_df.clip_name==clip_name)]
    temp_df['mean_c_x'] = temp_df.groupby('basis')['c_x'].transform('mean')
    temp_df['std_c_x'] = temp_df.groupby('basis')['c_x'].transform('std')
    temp_df['mean_c_y'] = temp_df.groupby('basis')['c_y'].transform('mean')
    temp_df['std_c_y'] = temp_df.groupby('basis')['c_y'].transform('std')
    temp_df['mean_c_z'] = temp_df.groupby('basis')['c_z'].transform('mean')
    temp_df['std_c_z'] = temp_df.groupby('basis')['c_z'].transform('std')
    # The aggregates are identical for every participant, so keep one row per
    # basis by filtering to a single pid.
    temp_df = temp_df[temp_df.pid==1]
    
#     visibility = 'legendonly'
#     if (clip_name=='oceans'):
#         visibility = True
    visibility = True

    for row, var in enumerate(['x','y','z']):
        row += 1  # subplot rows are 1-indexed
        mean = 'mean_c_'+var
        std = 'std_c_'+var
        # Only the x-row contributes legend entries; legendgroup ties the
        # remaining traces of the same clip to them.
        if (var=='x'):
            showlegend=True
        else:
            showlegend=False
        
        # c (no std)
        mean_traj = go.Scatter(
            x=temp_df['basis'],
            y=temp_df[mean],
            customdata=temp_df[std],
            mode='markers+lines',
            line={'width':2, 'color':colors[clip]},
            marker={'size':4, 'color':colors[clip]},
            name=clip_name,
            legendgroup=clip_name,
            showlegend=showlegend,
            visible=visibility,
            hovertemplate='basis: %{x}<br>coeff: %{y:.3f}<br>sd: %{customdata:.3f}'
        )
        fig.add_trace(mean_traj, row=row, col=1)

        # c (std) — the same mean line, drawn again in the right-hand column
        # on top of the shaded band added below.
        mean_traj = go.Scatter(
            x=temp_df['basis'],
            y=temp_df[mean],
            customdata=temp_df[std],
            mode='markers+lines',
            line={'width':2, 'color':colors[clip]},
            marker={'size':4, 'color':colors[clip]},
            name=clip_name,
            legendgroup=clip_name,
            showlegend=False,
            visible=visibility,
            hovertemplate='basis: %{x}<br>coeff: %{y:.3f}<br>sd: %{customdata:.3f}'
        )
        fig.add_trace(mean_traj, row=row, col=2)

        # Closed polygon for the +/- 1 sd band: forward along the upper edge,
        # back along the lower edge.
        upper = temp_df[mean] + temp_df[std]
        lower = temp_df[mean] - temp_df[std]
        std_traj = go.Scatter(
            # NOTE(review): x is rebuilt from the DataFrame index shifted to
            # start at 0 rather than from temp_df['basis'] — presumably
            # equivalent because each clip's rows are contiguous and ordered
            # by basis; confirm.
            x=np.concatenate([temp_df.index, temp_df.index[::-1]])-temp_df.index[0],
            y=pd.concat([upper, lower[::-1]]),
            fill='toself',
            mode='lines',
            line={'width':0, 'color':colors[clip]},
            opacity=0.7,
            name=clip_name,
            legendgroup=clip_name,
            showlegend=False,
            visible=visibility,
            hoverinfo='skip'
        )
        fig.add_trace(std_traj, row=row, col=2)

# formatting
fig.update_layout(
    autosize=False,
    showlegend=True,
    width=800, 
    height=1200, 
    margin={'l':0, 'r':0, 't':70, 'b':120},
    legend={'orientation':'h',
            'itemsizing':'constant',
            'xanchor':'center',
            'yanchor':'bottom',
            'x':0.5,
            'y':-0.07,
            'tracegroupgap':2},
    title={'text':'Mean Individual Trajectory B-Spline Basis Function Coefficients',
            'xanchor':'center',
            'yanchor':'top',
            'x':0.5,
            'y':0.98},
    hovermode='closest')
# Figure caption is added as a paper-anchored annotation below the plot area.
fig['layout']['annotations'] += (
    {'xref':'paper',
     'yref':'paper',
     'xanchor':'center',
     'yanchor':'bottom',
     'x':0.5,
     'y':-0.12,
     'showarrow':False,
     'text':'<b>Fig. 1.</b> Mean basis coefficients across all participants for each clip.<br>Error bars show the standard deviation of the mean basis coefficients.'
    },
)

plotly_config = {'displaylogo':False,
                 'modeBarButtonsToRemove': ['autoScale2d','toggleSpikelines','hoverClosestCartesian','hoverCompareCartesian','lasso2d','select2d']}

fig.show(config=plotly_config)

These coefficients look reasonably separable, so classification seems possible. Each clip’s mean coefficients are fairly distanced from those of other clips. The standard deviation isn’t too high, which indicates lower variability and should make it easier to consistently predict clips. However, there is some overlap between clips.

Classification

Dataset

Before beginning classification, we want to scale our features using mean normalization:

\[x' = \dfrac{x-mean(x)}{max(x)-min(x)}\]

This restricts all features between -1 and 1 as well as setting the mean of each feature to 0.

# Mean-normalize each coefficient column: x' = (x - mean(x)) / (max(x) - min(x)),
# which centers every feature at 0 and bounds it within [-1, 1].
for col in ['c_x', 'c_y', 'c_z']:
    col_range = coeff_df[col].max() - coeff_df[col].min()
    coeff_df[col] = (coeff_df[col] - coeff_df[col].mean()) / col_range

display(coeff_df)
clip clip_name clip_len pid basis c_x c_y c_z
0 0 testretest 84 1 0 -0.028804 0.019545 -0.050679
1 0 testretest 84 1 1 -0.006083 0.028310 -0.028942
2 0 testretest 84 1 2 -0.044562 0.013203 -0.016086
3 0 testretest 84 1 3 -0.049396 -0.022744 -0.022594
4 0 testretest 84 1 4 -0.102562 -0.043351 0.017825
... ... ... ... ... ... ... ... ...
68395 14 starwars 256 76 45 -0.113693 -0.211393 0.295556
68396 14 starwars 256 76 46 -0.100699 -0.131153 0.225961
68397 14 starwars 256 76 47 -0.042753 -0.335407 0.353122
68398 14 starwars 256 76 48 -0.204369 -0.173445 0.113999
68399 14 starwars 256 76 49 -0.084792 -0.198950 0.259541

68400 rows × 8 columns

To represent a trajectory as a single input to a model, we can combine the x, y, and z coefficients across all bases for a single participant and clip. This yields a coefficient array of length 150 (50 bases * 3 dimensions) for each participant/clip combination. Essentially, the ‘c’ column represents a 2d array of coefficients, which is also stored in a numpy array for convenience.

# Number of bases per trajectory, recovered from the table itself.
num_coeff = int(coeff_df['basis'].max()) + 1

# Rows are ordered (clip, pid, basis) and each carries one (c_x, c_y, c_z)
# triple, so a single reshape yields one length num_coeff*3 vector per
# trajectory with coefficients interleaved per basis — exactly the layout the
# original element-by-element append loop built, but in O(n) instead of the
# accidental O(n^2) of repeated np.vstack/np.append.
coeff_np = coeff_df[['c_x','c_y','c_z']].to_numpy().reshape(-1, num_coeff*3)

# Collapse the table to one row per trajectory and store each full
# coefficient vector in a single object column 'c'.
coeff_df = coeff_df[coeff_df.basis==0].drop(columns=['basis','c_y','c_z'])
clip_np = coeff_df['clip'].to_numpy()
coeff_df = coeff_df.astype(object)
c_col = coeff_df.columns.get_loc('c_x')
for i in range(len(coeff_df)):
    # .iat writes directly into the frame; the original chained
    # coeff_df['c_x'].iloc[i] = ... assignment writes through an intermediate
    # and only appeared to work because warnings are suppressed.
    coeff_df.iat[i, c_col] = coeff_np[i]
coeff_df = coeff_df.rename(columns={'c_x': 'c'})
coeff_df = coeff_df.reset_index(drop=True)

display(coeff_df)
clip clip_name clip_len pid c
0 0 testretest 84 1 [-0.028803697250869655, 0.019544595001130774, ...
1 0 testretest 84 2 [-0.032448210536812544, 0.005085447465171963, ...
2 0 testretest 84 3 [-0.0006059249937455069, 0.001414485757610391,...
3 0 testretest 84 4 [-0.03930034300649079, -0.015161772590474965, ...
4 0 testretest 84 5 [-0.04342440895769772, 0.0006228090080267724, ...
... ... ... ... ... ...
1363 14 starwars 256 72 [-0.06408953966162179, -0.0030453386333960685,...
1364 14 starwars 256 73 [-0.01787165594939609, -0.04387254678674551, -...
1365 14 starwars 256 74 [-0.03849370312026117, -0.005442888591800076, ...
1366 14 starwars 256 75 [-0.002775289504887294, 0.03133994037808111, -...
1367 14 starwars 256 76 [-0.023375988154500302, 0.026390365007979213, ...

1368 rows × 5 columns

Having reduced each trajectory to coefficients, we can try classifying clips by providing a model with coefficients rather than spatial-temporal information about the trajectory itself.

MLP

We can begin with a simple MLP model with one hidden layer for classification. The model will receive all coefficients as input. This comes in the form of 50 coefficients for each dimension, totalling 150 coefficients.

We’ll define methods for training, evaluation, and nested cross validation using grid search.

class MLP(nn.Module):
    """Single-hidden-layer perceptron for clip classification.

    Maps a flat coefficient vector to unnormalized class scores (logits),
    suitable for use with nn.CrossEntropyLoss.
    """

    def __init__(self, input_dim, hidden_dim_1, output_dim):
        super().__init__()
        self.hidden_1 = nn.Linear(input_dim, hidden_dim_1)
        self.output = nn.Linear(hidden_dim_1, output_dim)

    def forward(self, x):
        # One ReLU-activated hidden layer, then a linear readout.
        hidden = F.relu(self.hidden_1(x))
        return self.output(hidden)
def train(model, criterion, optimizer, dataloader, device): 
    model.train()
    model = model.to(device)
    criterion = criterion.to(device)
    
    running_loss = 0.0
    running_corrects = 0.0

    for inputs, labels in dataloader:
        
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        _, preds = torch.max(outputs, 1)
        running_loss += loss.detach() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)
        
    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = running_corrects.float() / len(dataloader.dataset)
    
    return epoch_loss, epoch_acc



def evaluate(model, criterion, dataloader, device):
    model.eval()
    model = model.to(device)
    criterion = criterion.to(device)
    
    running_loss = 0.0
    running_corrects = 0.0
    
    with torch.no_grad():
        for inputs, labels in dataloader:

            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)
            running_loss += loss.detach() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)
        
    epoch_loss = running_loss / len(dataloader.dataset)
    epoch_acc = running_corrects.float() / len(dataloader.dataset)
    
    return epoch_loss, epoch_acc
def grid_search(model, train_set, train_dl, val_set, val_dl, criterion, num_epochs, lr_params, device):
    """Evaluate each candidate learning rate and return its validation loss.

    Parameters
    ----------
    model : nn.Module
        Untrained template model; it is never trained directly.
    train_set, val_set : Dataset
        Unused; kept for interface compatibility with existing callers.
    train_dl, val_dl : DataLoader
        Loaders for the inner-fold train and validation splits.
    num_epochs : int
        Epochs to train each candidate.
    lr_params : sequence of float
        Candidate learning rates for Adam.

    Returns
    -------
    np.ndarray of shape (len(lr_params),) with the validation loss per lr.
    """
    import copy  # local import: keeps the file's top-level imports untouched

    val_losses = np.zeros(len(lr_params))

    for i, lr in enumerate(lr_params):
        # Bug fix: previously the SAME model instance was trained across all
        # candidate learning rates, so each lr was evaluated on a model that
        # had already been trained with every earlier lr. Start each
        # candidate from an identical fresh copy of the template instead.
        candidate = copy.deepcopy(model)
        optimizer = optim.Adam(candidate.parameters(), lr=lr)

        for epoch in range(num_epochs):
            train(candidate, criterion, optimizer, train_dl, device)
        val_loss, val_acc = evaluate(candidate, criterion, val_dl, device)

        val_losses[i] = val_loss

    return val_losses
def nested_cv(model, dataset, outer_kfold, num_outer_epochs, inner_kfold, num_inner_epochs, criterion, batch_size, lr_params, device, nested):
    """Run (optionally nested) k-fold cross validation and plot the results.

    Outer folds provide trainval/test splits. When `nested` is True, an inner
    k-fold grid search over `lr_params` selects the learning rate used for
    each outer fold; otherwise Adam's default learning rate is used.

    Parameters
    ----------
    model : nn.Module
        Untrained template model; each fold trains a fresh copy.
    dataset : Dataset supporting full slicing (e.g. TensorDataset)
    outer_kfold, inner_kfold : sklearn KFold splitters
    num_outer_epochs, num_inner_epochs : int
    criterion : loss module (e.g. nn.CrossEntropyLoss)
    batch_size : int
    lr_params : sequence of float — grid-search candidates
    device : torch.device
    nested : bool — whether to run the inner grid search

    Displays a plotly figure of per-epoch mean/std loss and accuracy; no
    return value.
    """
    import copy  # local: template-model cloning
    import time  # local: `time` was never imported at module level

    # Fix: time.clock() was removed in Python 3.8; perf_counter is the
    # modern monotonic replacement.
    time_start = time.perf_counter()

    # Collect per-epoch rows and build the DataFrame once at the end —
    # DataFrame.append was removed in pandas 2.0 (and was quadratic anyway).
    records = []
    x, y = dataset[:]

    # Keep an untouched template of the model. deepcopy replaces the
    # original torch.save/torch.load round-trip through 'untrained.pt',
    # which is slower and depends on full-object pickling.
    template = copy.deepcopy(model)

    # Outer CV (trainval/test split)
    current_outer_fold = 1
    for trainval_index, test_index in outer_kfold.split(x, y):

        outer_model = copy.deepcopy(template)

        x_trainval = x[trainval_index]
        y_trainval = y[trainval_index]
        x_test = x[test_index]
        y_test = y[test_index]
        trainval_data = TensorDataset(x_trainval, y_trainval)
        test_data = TensorDataset(x_test, y_test)

        # Inner CV (train/val split): grid-search the learning rate, summing
        # validation losses across inner folds.
        if nested:
            total_val_loss = np.zeros(len(lr_params))

            for train_index, val_index in inner_kfold.split(x_trainval, y_trainval):

                inner_model = copy.deepcopy(template)

                train_data = TensorDataset(x_trainval[train_index], y_trainval[train_index])
                val_data = TensorDataset(x_trainval[val_index], y_trainval[val_index])
                train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=True)
                val_dl = DataLoader(val_data, batch_size=batch_size, shuffle=True)

                fold_val_loss = grid_search(inner_model, train_data, train_dl, val_data, val_dl, criterion, num_inner_epochs, lr_params, device)
                total_val_loss = np.add(total_val_loss, fold_val_loss)

            best_lr = lr_params[np.argmin(total_val_loss)]
            optimizer = optim.Adam(outer_model.parameters(), lr=best_lr)

        # Non-nested CV: fall back to Adam's default learning rate.
        else:
            optimizer = optim.Adam(outer_model.parameters())

        trainval_dl = DataLoader(trainval_data, batch_size=batch_size, shuffle=True)
        test_dl = DataLoader(test_data, batch_size=batch_size, shuffle=True)

        # Train on the full trainval split, evaluating on the held-out test
        # split after every epoch.
        for epoch in range(num_outer_epochs):
            trainval_loss, trainval_acc = train(outer_model, criterion, optimizer, trainval_dl, device)
            test_loss, test_acc = evaluate(outer_model, criterion, test_dl, device)
            records.append({
                'fold': current_outer_fold,
                'epoch': epoch + 1,
                'train_loss': trainval_loss.item(),
                'train_acc': trainval_acc.item(),
                'test_loss': test_loss.item(),
                'test_acc': test_acc.item()})

        current_outer_fold += 1

    loss_acc_df = pd.DataFrame.from_records(
        records, columns=['fold', 'epoch', 'train_loss', 'train_acc', 'test_loss', 'test_acc'])

    _plot_cv_results(loss_acc_df)


def _plot_cv_results(loss_acc_df):
    """Plot per-epoch mean/std of train and test loss and accuracy across folds."""
    # Per-epoch mean and std across the outer folds, broadcast onto each row.
    for metric in ['train_loss', 'train_acc', 'test_loss', 'test_acc']:
        loss_acc_df['mean_' + metric] = loss_acc_df.groupby('epoch')[metric].transform('mean')
        loss_acc_df['std_' + metric] = loss_acc_df.groupby('epoch')[metric].transform('std')
    # The aggregates are duplicated across folds; keep a single copy per epoch.
    loss_acc_df = loss_acc_df[loss_acc_df.fold == 1]

    fig = make_subplots(rows=2, cols=1, 
                        shared_xaxes=True,
                        vertical_spacing=0.05,
                        subplot_titles=('Loss','Accuracy'), 
                        specs=[[{'type':'scatter'}], [{'type':'scatter'}]])

    for split in ['train', 'test']:
        color = 'mediumblue' if split == 'train' else 'crimson'

        # loss (no std)
        loss = go.Scatter(
            x=loss_acc_df['epoch'],
            y=loss_acc_df['mean_'+split+'_loss'],
            customdata=loss_acc_df['std_'+split+'_loss'],
            mode='markers+lines',
            line={'width':2, 'color':color},
            marker={'size':4, 'color':color},
            name=split,
            legendgroup=split,
            showlegend=True,
            visible=True,
            hovertemplate='epoch: %{x}<br>loss: %{y:.3f}<br>sd: %{customdata:.3f}'
        )
        fig.add_trace(loss, row=1, col=1)

        # loss (std): closed polygon, forward along the upper edge and back
        # along the lower edge.
        upper = loss_acc_df['mean_'+split+'_loss'] + loss_acc_df['std_'+split+'_loss']
        lower = loss_acc_df['mean_'+split+'_loss'] - loss_acc_df['std_'+split+'_loss']
        loss = go.Scatter(
            x=np.concatenate([loss_acc_df.index, loss_acc_df.index[::-1]])-loss_acc_df.index[0]+1,
            y=pd.concat([upper, lower[::-1]]),
            fill='toself',
            mode='lines',
            line={'width':0, 'color':color},
            opacity=0.7,
            name=split,
            legendgroup=split,
            showlegend=False,
            visible=True,
            hoverinfo='skip'
        )
        fig.add_trace(loss, row=1, col=1)

        # acc (no std)
        acc = go.Scatter(
            x=loss_acc_df['epoch'],
            y=loss_acc_df['mean_'+split+'_acc'],
            customdata=loss_acc_df['std_'+split+'_acc'],
            mode='markers+lines',
            line={'width':2, 'color':color},
            marker={'size':4, 'color':color},
            name=split,
            legendgroup=split,
            showlegend=False,
            visible=True,
            hovertemplate='epoch: %{x}<br>acc: %{y:.3f}<br>sd: %{customdata:.3f}'
        )
        fig.add_trace(acc, row=2, col=1)

        # acc (std)
        upper = loss_acc_df['mean_'+split+'_acc'] + loss_acc_df['std_'+split+'_acc']
        lower = loss_acc_df['mean_'+split+'_acc'] - loss_acc_df['std_'+split+'_acc']
        acc = go.Scatter(
            x=np.concatenate([loss_acc_df.index, loss_acc_df.index[::-1]])-loss_acc_df.index[0]+1,
            y=pd.concat([upper, lower[::-1]]),
            fill='toself',
            mode='lines',
            line={'width':0, 'color':color},
            opacity=0.7,
            name=split,
            legendgroup=split,
            showlegend=False,
            visible=True,
            hoverinfo='skip'
        )
        fig.add_trace(acc, row=2, col=1)

    # formatting
    fig.update_layout(
        autosize=False,
        width=800, 
        height=800, 
        margin={'l':0, 'r':0, 't':70, 'b':100},
        legend={'orientation':'h',
                'itemsizing':'constant',
                'xanchor':'center',
                'yanchor':'bottom',
                'y':-0.07,
                'x':0.5},
        title={'text':'Nested Cross Validation Loss and Accuracy',
                'xanchor':'center',
                'yanchor':'top',
                'x':0.5,
                'y':0.98},
        hovermode='x')
    # Figure caption as a paper-anchored annotation below the plot area.
    fig['layout']['annotations'] += (
        {'xref':'paper',
         'yref':'paper',
         'xanchor':'center',
         'yanchor':'bottom',
         'x':0.5,
         'y':-0.14,
         'showarrow':False,
         'text':'<b>Fig. 2.</b> Mean loss and accuracy for train and test sets across outer folds of nested cross validation.<br>Error bars show the standard deviation of the mean loss and accuracy at each epoch.'
        },
    )

    plotly_config = {'displaylogo':False,
                     'modeBarButtonsToRemove': ['autoScale2d','toggleSpikelines','hoverClosestCartesian','hoverCompareCartesian','lasso2d','select2d']}

    fig.show(config=plotly_config)
# set seed for reproducibility
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# create dataset: B-spline coefficients as inputs, clip indices as labels
coefficients = torch.from_numpy(coeff_np.astype(np.float32))
# np.int was removed in NumPy 1.24; use the explicit np.int64 alias so the
# labels are always int64, as required by nn.CrossEntropyLoss targets
clips = torch.from_numpy(clip_np.astype(np.int64))
coeff_dataset = TensorDataset(coefficients, clips)

nested_cv(
    model=MLP(150, 75, 15),
    dataset=coeff_dataset,
    outer_kfold=KFold(n_splits=5, shuffle=True, random_state=0),
    num_outer_epochs=30,
    inner_kfold=KFold(n_splits=5, shuffle=True, random_state=0),
    num_inner_epochs=20,
    criterion=nn.CrossEntropyLoss(),
    batch_size=32,
    lr_params=[0.001,0.003,0.01],
    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    nested=True)

86% accuracy is extremely good for a 15-way classification, so we should be suspicious. To make sure the learned decision boundary isn’t due to chance, we can try randomly shuffling the labels. This should disrupt the data’s natural class-separability and severely reduce classification accuracy.

# sanity check: np.random.shuffle permutes the copy in place while leaving
# the original array untouched
np.random.seed(0)
yay = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9])
# np.int was removed in NumPy 1.24; np.int64 is the explicit equivalent
yay = yay.astype(np.int64)
yay_clone = np.copy(yay)
np.random.shuffle(yay_clone)
print(type(yay))
print(yay)
print(type(yay_clone))
print(yay_clone)
<class 'numpy.ndarray'>
[1 2 3 4 5 6 7 8 9]
<class 'numpy.ndarray'>
[8 3 2 5 9 7 4 1 6]
# set seed for reproducibility
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(0)

# create dataset with shuffled labels (permutation test: destroys the
# input-label relationship while keeping the label distribution)
coefficients = torch.from_numpy(coeff_np.astype(np.float32))
clip_np_copy = np.copy(clip_np)
np.random.shuffle(clip_np_copy)
# np.int was removed in NumPy 1.24; np.int64 keeps labels int64 as
# required by nn.CrossEntropyLoss targets
clips = torch.from_numpy(clip_np_copy.astype(np.int64))
coeff_dataset = TensorDataset(coefficients, clips)

nested_cv(
    model=MLP(150, 75, 15),
    dataset=coeff_dataset,
    outer_kfold=KFold(n_splits=5, shuffle=True, random_state=0),
    num_outer_epochs=30,
    inner_kfold=KFold(n_splits=5, shuffle=True, random_state=0),
    num_inner_epochs=20,
    criterion=nn.CrossEntropyLoss(),
    batch_size=32,
    lr_params=[0.001,0.003,0.01],
    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    nested=True)

The accuracy drops to 22% when the labels are shuffled, which suggests that the 86% accuracy of the original model reflects genuine class-separability in the data rather than chance.

LSTM

RNNs do have an advantage over MLPs since they can use information from hidden states, effectively creating a memory system that allows the RNN to use information from previous inputs rather than just the current input.

This is especially relevant with trajectories since they represent a state that is changing over time. This concept is retained when using B-splines. Each basis function contributes to a part of the B-spline defined by knots, which correspond to times in the trajectory. Since the basis functions are ordered with respect to time, we can also input their coefficients in order.

# demo: a flat (4, 15) tensor of consecutive integers can be reinterpreted
# as (4, 5, 3) — i.e. 4 samples of 5 timesteps with 3 features each — which
# is the layout the LSTM consumes
a = torch.arange(1, 61).view(4, 15)
print(a.shape)
print(a)
print()
a = a.view(4, 5, 3)
print(a.shape)
print(a)
print()
print(a[-1])
torch.Size([4, 15])
tensor([[ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30],
        [31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45],
        [46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60]])

torch.Size([4, 5, 3])
tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9],
         [10, 11, 12],
         [13, 14, 15]],

        [[16, 17, 18],
         [19, 20, 21],
         [22, 23, 24],
         [25, 26, 27],
         [28, 29, 30]],

        [[31, 32, 33],
         [34, 35, 36],
         [37, 38, 39],
         [40, 41, 42],
         [43, 44, 45]],

        [[46, 47, 48],
         [49, 50, 51],
         [52, 53, 54],
         [55, 56, 57],
         [58, 59, 60]]])

tensor([[46, 47, 48],
        [49, 50, 51],
        [52, 53, 54],
        [55, 56, 57],
        [58, 59, 60]])
class LSTM(nn.Module):
    """LSTM classifier over flattened sequences of B-spline coefficients.

    Args:
        input_dim: number of features per timestep (e.g. 3 for x/y/z coefficients).
        hidden_dim: size of the LSTM hidden state.
        output_dim: number of classes.
        num_layers: number of stacked LSTM layers.
        dropout: dropout probability applied between stacked LSTM layers.
    """
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, dropout):
        super().__init__()
        self.input_dim = input_dim
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True,
                            num_layers=num_layers, dropout=dropout)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # reshape flat vectors to (batch_size, seq_len, input_dim); seq_len is
        # inferred with -1 instead of the previous hard-coded 50, so the model
        # accepts any number of basis functions per dimension
        batch_size = x.shape[0]
        x = x.view(batch_size, -1, self.input_dim)
        lstm_out, (hidden, cell) = self.lstm(x)
        # classify from the last layer's final hidden state
        return self.fc(hidden[-1])
# set seed for reproducibility
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# create dataset: B-spline coefficients as inputs, clip indices as labels
coefficients = torch.from_numpy(coeff_np.astype(np.float32))
# np.int was removed in NumPy 1.24; np.int64 keeps labels int64 as
# required by nn.CrossEntropyLoss targets
clips = torch.from_numpy(clip_np.astype(np.int64))
coeff_dataset = TensorDataset(coefficients, clips)

nested_cv(
    model=LSTM(3, 10, 15, num_layers=2, dropout=0.5),
    dataset=coeff_dataset,
    outer_kfold=KFold(n_splits=5, shuffle=True, random_state=0),
    num_outer_epochs=30,
    inner_kfold=KFold(n_splits=5, shuffle=True, random_state=0),
    num_inner_epochs=15,
    criterion=nn.CrossEntropyLoss(),
    batch_size=32,
    lr_params=[0.001,0.003,0.01],
    device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'),
    nested=True)